{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3d8b830d-ee96-4d62-a8ce-da3469227cf0",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages/numpy/core/getlimits.py:549: UserWarning: The value of the smallest subnormal for <class 'numpy.float64'> type is zero.\n",
      "  setattr(self, word, getattr(machar, word).flat[0])\n",
      "/home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages/numpy/core/getlimits.py:89: UserWarning: The value of the smallest subnormal for <class 'numpy.float64'> type is zero.\n",
      "  return self._float_to_str(self.smallest_subnormal)\n",
      "/home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages/numpy/core/getlimits.py:549: UserWarning: The value of the smallest subnormal for <class 'numpy.float32'> type is zero.\n",
      "  setattr(self, word, getattr(machar, word).flat[0])\n",
      "/home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages/numpy/core/getlimits.py:89: UserWarning: The value of the smallest subnormal for <class 'numpy.float32'> type is zero.\n",
      "  return self._float_to_str(self.smallest_subnormal)\n"
     ]
    }
   ],
   "source": [
    "import mindspore as ms\n",
    "import mindspore.nn as nn\n",
    "from mindspore import ops\n",
    "from mindspore.dataset import GeneratorDataset\n",
    "import numpy as np\n",
    "from mindnlp.engine import Trainer, TrainingArguments\n",
    "from mindnlp.transformers import (\n",
    "    BloomForCausalLM, \n",
    "    BloomConfig, \n",
    "    BloomTokenizerFast,\n",
    ")\n",
    "from mindnlp.peft import LoraConfig, get_peft_model\n",
    "import gc\n",
    "from datasets import load_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "169fd2e6-7a43-46d0-8fec-64358b631661",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Execution mode and target device\n",
    "ms.set_context(mode=ms.PYNATIVE_MODE, device_target=\"Ascend\")\n",
    "\n",
    "# Seed MindSpore's global RNG and the dataset pipeline RNG so that random\n",
    "# initialisation and dataset shuffling are repeatable across Restart & Run All\n",
    "SEED = 42\n",
    "ms.set_seed(SEED)\n",
    "ms.dataset.config.set_seed(SEED)\n",
    "\n",
    "# Basic configuration\n",
    "MODEL = \"bigscience/bloom-3b\"  # pretrained checkpoint to fine-tune\n",
    "DATASET = \"databricks/databricks-dolly-15k\"  # dataset passed to load_dataset\n",
    "TOKENS = 20  # max_new_tokens used when generating text\n",
    "EPOCHS = 10  # number of training epochs\n",
    "BATCH_SIZE = 4  # per-device batch size for training and evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1af049aa-17dd-4587-976f-b8c45de33706",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Fetch the model configuration and the matching fast tokenizer for MODEL\n",
    "config = BloomConfig.from_pretrained(MODEL)\n",
    "tokenizer = BloomTokenizerFast.from_pretrained(MODEL)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d197447e-1381-496d-825c-739dcc86f82f",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "BloomForCausalLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`.`PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.\n",
      "  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).\n",
      "  - If you are not the owner of the model architecture class, please contact the model code owner to update it.\n"
     ]
    }
   ],
   "source": [
    "# Load the pretrained base model\n",
    "base_model = BloomForCausalLM.from_pretrained(MODEL)\n",
    "\n",
    "# LoRA hyper-parameters, gathered in one place before building the config\n",
    "lora_settings = dict(\n",
    "    r=16,\n",
    "    lora_alpha=16,\n",
    "    lora_dropout=0.05,\n",
    "    target_modules=[\"query_key_value\"],\n",
    "    bias=\"none\",\n",
    "    task_type=\"CAUSAL_LM\",\n",
    ")\n",
    "lora_config = LoraConfig(**lora_settings)\n",
    "\n",
    "# Wrap the base model as a LoRA (PEFT) model\n",
    "model = get_peft_model(base_model, lora_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e5f54117-aff8-4953-9297-5f5f6474faab",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Make sure a padding token exists: fall back to the EOS token when the\n",
    "# tokenizer does not define one\n",
    "if tokenizer.pad_token is None:\n",
    "    tokenizer.pad_token = tokenizer.eos_token"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "7f7d4d94-5be7-4e0a-9ff0-e286b0c8af5f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Load the dataset named by DATASET\n",
    "dataset = load_dataset(DATASET)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "9ef3eec9-664d-4d61-b9a1-95f4da14a12e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def format_prompt(sample):\n",
    "    \"\"\"Build the training prompt for one dataset record.\n",
    "\n",
    "    The prompt consists of an ``### Instruction`` section, an optional\n",
    "    ``### Context`` section (skipped when the context is empty, None or\n",
    "    absent) and an ``### Answer`` section, joined by blank lines.\n",
    "\n",
    "    Args:\n",
    "        sample: mapping with ``instruction`` and ``response`` keys and an\n",
    "            optional ``context`` key.\n",
    "\n",
    "    Returns:\n",
    "        The same ``sample`` object, with the text stored under ``prompt``.\n",
    "    \"\"\"\n",
    "    sections = [f\"### Instruction\\n{sample['instruction']}\"]\n",
    "    # A missing or None context is treated like an empty one instead of\n",
    "    # raising on len(None) / a missing key\n",
    "    context = sample.get(\"context\") or \"\"\n",
    "    if len(context) > 0:\n",
    "        sections.append(f\"### Context\\n{context}\")\n",
    "    sections.append(f\"### Answer\\n{sample['response']}\")\n",
    "    sample[\"prompt\"] = \"\\n\\n\".join(sections)\n",
    "    return sample\n",
    "\n",
    "# Build prompts only for the rows that are actually used, and leave the\n",
    "# loaded `dataset` untouched so this cell can be re-run without failing on\n",
    "# already-removed columns\n",
    "N_TRAIN = 40  # rows 0..39 are used for training\n",
    "N_EVAL = 10   # rows 40..49 are used for evaluation\n",
    "prompt_subset = (\n",
    "    dataset[\"train\"]\n",
    "    .select(range(N_TRAIN + N_EVAL))\n",
    "    .map(format_prompt)\n",
    "    .remove_columns(['instruction', 'context', 'response', 'category'])\n",
    ")\n",
    "\n",
    "train_samples = prompt_subset.select(range(0, N_TRAIN))\n",
    "eval_samples = prompt_subset.select(range(N_TRAIN, N_TRAIN + N_EVAL))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "9b9e6d98-f3e5-4536-8f67-b06a25e92277",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "class TextDataset:\n",
    "    \"\"\"Random-access source that tokenises prompts for causal-LM training.\n",
    "\n",
    "    Each item is a tuple ``(input_ids, attention_mask, labels)``. ``labels``\n",
    "    equals ``input_ids`` except at padding positions (attention mask 0),\n",
    "    which are set to ``IGNORE_INDEX`` so that padding does not contribute\n",
    "    to the loss.\n",
    "\n",
    "    Args:\n",
    "        data: indexable collection of records with a ``prompt`` field.\n",
    "        max_length: length every sequence is padded / truncated to.\n",
    "        tokenizer: tokenizer to use; when None the notebook-level\n",
    "            ``tokenizer`` is looked up at access time.\n",
    "    \"\"\"\n",
    "\n",
    "    # Default ignore_index of the cross-entropy loss\n",
    "    IGNORE_INDEX = -100\n",
    "\n",
    "    def __init__(self, data, max_length=256, tokenizer=None):\n",
    "        self.data = data\n",
    "        self.max_length = max_length\n",
    "        self.tokenizer = tokenizer\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        # The dataset engine may pass a non-int index type; normalise it\n",
    "        index = int(index)\n",
    "        text = self.data[index][\"prompt\"]\n",
    "        active_tokenizer = self.tokenizer if self.tokenizer is not None else tokenizer\n",
    "        inputs = active_tokenizer(\n",
    "            text, padding='max_length', max_length=self.max_length, truncation=True\n",
    "        )\n",
    "        input_ids = inputs[\"input_ids\"]\n",
    "        attention_mask = inputs[\"attention_mask\"]\n",
    "\n",
    "        # Labels mirror input_ids, with padded positions masked out\n",
    "        labels = [\n",
    "            token_id if keep else self.IGNORE_INDEX\n",
    "            for token_id, keep in zip(input_ids, attention_mask)\n",
    "        ]\n",
    "        return input_ids, attention_mask, labels\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "b2575d0e-951f-4224-bf9b-53915cdcafd2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Create the MindSpore datasets; every item provides input_ids,\n",
    "# attention_mask and labels\n",
    "def _to_generator_dataset(samples, shuffle):\n",
    "    \"\"\"Wrap prompt samples in a GeneratorDataset with the three model columns.\"\"\"\n",
    "    return GeneratorDataset(\n",
    "        TextDataset(samples),\n",
    "        column_names=[\"input_ids\", \"attention_mask\", \"labels\"],\n",
    "        shuffle=shuffle,\n",
    "    )\n",
    "\n",
    "\n",
    "train_dataset = _to_generator_dataset(train_samples, shuffle=True)\n",
    "eval_dataset = _to_generator_dataset(eval_samples, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "08d4d7d7-5a63-4bc2-8588-8039af71987b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Training hyper-parameters, collected before building TrainingArguments\n",
    "training_options = dict(\n",
    "    output_dir='./work',\n",
    "    num_train_epochs=EPOCHS,\n",
    "    learning_rate=2e-4,\n",
    "    per_device_train_batch_size=BATCH_SIZE,\n",
    "    per_device_eval_batch_size=BATCH_SIZE,\n",
    "    auto_find_batch_size=True,\n",
    "    logging_strategy=\"epoch\",\n",
    "    evaluation_strategy=\"epoch\",\n",
    "    fp16=True,  # enable mixed-precision training\n",
    ")\n",
    "training_args = TrainingArguments(**training_options)\n",
    "\n",
    "# Assemble the trainer from the model, arguments, datasets and tokenizer\n",
    "trainer = Trainer(\n",
    "    args=training_args,\n",
    "    model=model,\n",
    "    tokenizer=tokenizer,\n",
    "    train_dataset=train_dataset,\n",
    "    eval_dataset=eval_dataset,\n",
    "    compute_metrics=None,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "8fa63696-2dc4-4851-9ed1-08adb4df558c",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6a5ca1ac0c784452b8b68f00e44494d3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using `past_key_values` as a tuple is deprecated. Please use an appropriate `Cache` class\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 8.5685, 'learning_rate': 0.00018, 'epoch': 1.0}\n",
      "/\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/3 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 6.447186470031738, 'eval_runtime': 8.8455, 'eval_samples_per_second': 0.339, 'eval_steps_per_second': 0.113, 'epoch': 1.0}\n",
      "{'loss': 6.7008, 'learning_rate': 0.00016, 'epoch': 2.0}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/3 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 5.485293388366699, 'eval_runtime': 0.4395, 'eval_samples_per_second': 6.826, 'eval_steps_per_second': 2.275, 'epoch': 2.0}\n",
      "{'loss': 4.3543, 'learning_rate': 0.00014, 'epoch': 3.0}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/3 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 3.417534351348877, 'eval_runtime': 0.474, 'eval_samples_per_second': 6.329, 'eval_steps_per_second': 2.11, 'epoch': 3.0}\n",
      "{'loss': 2.5833, 'learning_rate': 0.00012, 'epoch': 4.0}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/3 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 2.719986915588379, 'eval_runtime': 0.4308, 'eval_samples_per_second': 6.964, 'eval_steps_per_second': 2.321, 'epoch': 4.0}\n",
      "{'loss': 1.7503, 'learning_rate': 0.0001, 'epoch': 5.0}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/3 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 2.5181431770324707, 'eval_runtime': 0.4391, 'eval_samples_per_second': 6.832, 'eval_steps_per_second': 2.277, 'epoch': 5.0}\n",
      "{'loss': 1.4438, 'learning_rate': 8e-05, 'epoch': 6.0}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/3 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 2.516507148742676, 'eval_runtime': 0.4309, 'eval_samples_per_second': 6.962, 'eval_steps_per_second': 2.321, 'epoch': 6.0}\n",
      "{'loss': 1.3124, 'learning_rate': 6e-05, 'epoch': 7.0}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/3 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 2.4947118759155273, 'eval_runtime': 0.4491, 'eval_samples_per_second': 6.68, 'eval_steps_per_second': 2.227, 'epoch': 7.0}\n",
      "{'loss': 1.2436, 'learning_rate': 4e-05, 'epoch': 8.0}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/3 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 2.664604425430298, 'eval_runtime': 0.4324, 'eval_samples_per_second': 6.937, 'eval_steps_per_second': 2.312, 'epoch': 8.0}\n",
      "{'loss': 1.205, 'learning_rate': 2e-05, 'epoch': 9.0}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/3 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 2.468752145767212, 'eval_runtime': 0.4321, 'eval_samples_per_second': 6.943, 'eval_steps_per_second': 2.314, 'epoch': 9.0}\n",
      "{'loss': 1.1878, 'learning_rate': 0.0, 'epoch': 10.0}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/3 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 2.4858415126800537, 'eval_runtime': 0.4329, 'eval_samples_per_second': 6.93, 'eval_steps_per_second': 2.31, 'epoch': 10.0}\n",
      "{'train_runtime': 132.5653, 'train_samples_per_second': 3.017, 'train_steps_per_second': 0.754, 'train_loss': 3.0349820709228514, 'epoch': 10.0}\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=100, training_loss=3.0349820709228514, metrics={'train_runtime': 132.5653, 'train_samples_per_second': 3.017, 'train_steps_per_second': 0.754, 'train_loss': 3.0349820709228514, 'epoch': 10.0})"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Train the model; the value returned by train() is shown as the cell output\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "39203b95-f3b4-4b06-b8ed-a9f58f227799",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Keep a handle on the log records stored on the trainer state\n",
    "log_history = trainer.state.log_history"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "d913caf3-c179-4b11-832a-062f5790d354",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Collect the logged training and evaluation losses for the loss plot\n",
    "training_losses = []\n",
    "validation_losses = []\n",
    "for record in log_history:\n",
    "    if \"loss\" in record:\n",
    "        training_losses.append(record[\"loss\"])\n",
    "    if \"eval_loss\" in record:\n",
    "        validation_losses.append(record[\"eval_loss\"])\n",
    "\n",
    "# One x-axis position per logged training loss, counted from 1\n",
    "epochs = range(1, 1 + len(training_losses))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "8d57f6a8-f7ea-498f-8a37-6dbdfe696a0a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Draw both loss curves on one axes, write the figure to disk, then close it\n",
    "fig, ax = plt.subplots(figsize=(10, 6))\n",
    "ax.plot(epochs, training_losses, 'b-', label='Training Loss')\n",
    "ax.plot(epochs, validation_losses, 'r-', label='Validation Loss')\n",
    "ax.set_title('Training and Validation Loss')\n",
    "ax.set_xlabel('Epochs')\n",
    "ax.set_ylabel('Loss')\n",
    "ax.legend()\n",
    "fig.savefig('training_loss.png')\n",
    "plt.close(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "d2106841-3687-449b-995a-937181b3f197",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Text generation helper\n",
    "def generate_text(input_text):\n",
    "    \"\"\"Generate a continuation of ``input_text`` with the notebook's model.\n",
    "\n",
    "    The text is tokenised with the global ``tokenizer``; at most ``TOKENS``\n",
    "    new tokens are generated with a repetition penalty of 1.5, and the\n",
    "    outputs are decoded with special tokens skipped.\n",
    "\n",
    "    Returns:\n",
    "        The result of ``tokenizer.batch_decode`` on the generated sequences.\n",
    "    \"\"\"\n",
    "    encoded = tokenizer(input_text, return_tensors=\"ms\")\n",
    "    generation_kwargs = {\n",
    "        \"input_ids\": encoded[\"input_ids\"],\n",
    "        \"attention_mask\": encoded[\"attention_mask\"],\n",
    "        \"max_new_tokens\": TOKENS,\n",
    "        \"repetition_penalty\": 1.5,\n",
    "        \"eos_token_id\": tokenizer.eos_token_id,\n",
    "    }\n",
    "    generated_ids = model.generate(**generation_kwargs)\n",
    "    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "cd8869e0-3d34-48ef-8fa1-88c0ffc4a8dd",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['What is the meaning of life? What does it mean to be human and what makes us special from other animals, plants or even rocks']\n"
     ]
    }
   ],
   "source": [
    "# Try the model on a sample question and print the decoded output\n",
    "input_words = \"What is the meaning of life?\"\n",
    "generated_text = generate_text(input_words)\n",
    "print(generated_text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "425d5269-4168-4e8d-92cc-8d3299ed981a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Save the model through the trainer into ./lora_weights\n",
    "trainer.save_model(\"./lora_weights\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "f9b330de-95eb-40c8-9e95-bbeb16e96260",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "4191"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Free memory. `base_model` must go too: it still references the network\n",
    "# wrapped by `model`, so deleting `model` alone cannot release the weights.\n",
    "# pop(..., None) keeps the cell re-runnable once the names are already gone.\n",
    "for _name in (\n",
    "    \"model\",\n",
    "    \"base_model\",\n",
    "    \"trainer\",\n",
    "    \"train_dataset\",\n",
    "    \"eval_dataset\",\n",
    "    \"train_samples\",\n",
    "    \"eval_samples\",\n",
    "):\n",
    "    globals().pop(_name, None)\n",
    "gc.collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc19775d-1f5a-470d-83ba-1ad6f652fce4",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.20"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
